
Lesson 6: Seq2Seq and Attention

褚则伟 zeweichu@gmail.com

In this notebook we will (as far as possible) reproduce Luong's attention model.

Our dataset is very small, only a bit over ten thousand training sentence pairs, so the resulting model is not very good. If you want to train a stronger model, see the resources below.

Further reading

Slides

Papers

PyTorch code

More on machine translation

  • Beam Search (a minimal sketch follows this list)
  • Pointer Network (text summarization)
  • Copy Mechanism (text summarization)
  • Coverage Loss
  • ConvSeq2Seq
  • Transformer
  • Tensor2Tensor
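
One of the simplest extensions worth trying is beam search decoding, which keeps the k best partial translations instead of greedily taking the argmax at every step. Below is a minimal, model-agnostic sketch (not part of the original notebook); step_fn is an assumed callback that, given a token prefix, returns (token, log-probability) candidates for the next word.

def beam_search(step_fn, bos_id, eos_id, beam_size=5, max_len=20):
    # each hypothesis is (tokens, accumulated log-probability); start from BOS
    beams = [([bos_id], 0.0)]
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            if tokens[-1] == eos_id:           # finished hypotheses are carried over unchanged
                candidates.append((tokens, score))
                continue
            for tok, logp in step_fn(tokens):  # expand with every candidate next token
                candidates.append((tokens + [tok], score + logp))
        # keep only the beam_size best candidates
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
        if all(t[-1] == eos_id for t, _ in beams):
            break
    return beams[0]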

TODO

  • Try segmenting the Chinese side into words instead of single characters (see the sketch below)
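
For that TODO, the third-party jieba segmenter is one option (an assumption on my part; the notebook itself splits Chinese into characters in load_data):

import jieba

# word-level segmentation instead of character splitting
print(jieba.lcut("他来这里的目的是什么?"))
# e.g. ['他', '来', '这里', '的', '目的', '是', '什么', '?']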

import os
import sys
import math
from collections import Counter
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import nltk

Load the English-Chinese data

  • English: tokenize with nltk's word tokenizer and lowercase everything
  • Chinese: treat each character as a token
def load_data(in_file):
    cn = []
    en = []
    with open(in_file, 'r') as f:
        for line in f:
            line = line.strip().split("\t")

            en.append(["BOS"] + nltk.word_tokenize(line[0].lower()) + ["EOS"])
            # split the Chinese sentence into characters
            cn.append(["BOS"] + [c for c in line[1]] + ["EOS"])
    return en, cn

train_file = "nmt/en-cn/train.txt"
dev_file = "nmt/en-cn/dev.txt"
train_en, train_cn = load_data(train_file)
dev_en, dev_cn = load_data(dev_file)

Build the vocabularies

UNK_IDX = 0
PAD_IDX = 1

def build_dict(sentences, max_words=50000):
    word_count = Counter()
    for sentence in sentences:
        for s in sentence:
            word_count[s] += 1
    ls = word_count.most_common(max_words)
    total_words = len(ls) + 2
    word_dict = {w[0]: index + 2 for index, w in enumerate(ls)}
    word_dict["UNK"] = UNK_IDX
    word_dict["PAD"] = PAD_IDX
    return word_dict, total_words

en_dict, en_total_words = build_dict(train_en)
cn_dict, cn_total_words = build_dict(train_cn)
inv_en_dict = {v: k for k, v in en_dict.items()}
inv_cn_dict = {v: k for k, v in cn_dict.items()}

Convert all words to indices

def encode(en_sentences, cn_sentences, en_dict, cn_dict, sort_by_len=True):
    '''
    Encode the sequences.
    '''
    out_en_sentences = [[en_dict.get(w, 0) for w in sent] for sent in en_sentences]
    out_cn_sentences = [[cn_dict.get(w, 0) for w in sent] for sent in cn_sentences]

    # sort sentences by English length
    def len_argsort(seq):
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    # keep the Chinese and English sentences in the same order
    if sort_by_len:
        sorted_index = len_argsort(out_en_sentences)
        out_en_sentences = [out_en_sentences[i] for i in sorted_index]
        out_cn_sentences = [out_cn_sentences[i] for i in sorted_index]

    return out_en_sentences, out_cn_sentences

train_en, train_cn = encode(train_en, train_cn, en_dict, cn_dict)
dev_en, dev_cn = encode(dev_en, dev_cn, en_dict, cn_dict)
# train_cn[:10]
k = 10000
print(" ".join([inv_cn_dict[i] for i in train_cn[k]]))
print(" ".join([inv_en_dict[i] for i in train_en[k]]))
BOS 他 来 这 里 的 目 的 是 什 么 ? EOS
BOS for what purpose did he come here ? EOS

Split the sentences into batches

def get_minibatches(n, minibatch_size, shuffle=True):
    idx_list = np.arange(0, n, minibatch_size)  # batch start indices: [0, minibatch_size, 2*minibatch_size, ...]
    if shuffle:
        np.random.shuffle(idx_list)
    minibatches = []
    for idx in idx_list:
        minibatches.append(np.arange(idx, min(idx + minibatch_size, n)))
    return minibatches

def prepare_data(seqs):
    lengths = [len(seq) for seq in seqs]
    n_samples = len(seqs)
    max_len = np.max(lengths)

    # pad every sequence in the batch to the length of the longest one
    x = np.zeros((n_samples, max_len)).astype('int32')
    x_lengths = np.array(lengths).astype("int32")
    for idx, seq in enumerate(seqs):
        x[idx, :lengths[idx]] = seq
    return x, x_lengths

def gen_examples(en_sentences, cn_sentences, batch_size):
    minibatches = get_minibatches(len(en_sentences), batch_size)
    all_ex = []
    for minibatch in minibatches:
        mb_en_sentences = [en_sentences[t] for t in minibatch]
        mb_cn_sentences = [cn_sentences[t] for t in minibatch]
        mb_x, mb_x_len = prepare_data(mb_en_sentences)
        mb_y, mb_y_len = prepare_data(mb_cn_sentences)
        all_ex.append((mb_x, mb_x_len, mb_y, mb_y_len))
    return all_ex

batch_size = 64
train_data = gen_examples(train_en, train_cn, batch_size)
random.shuffle(train_data)
dev_data = gen_examples(dev_en, dev_cn, batch_size)
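
A quick way to inspect what a batch looks like (an illustrative check, not in the original notebook):

mb_x, mb_x_len, mb_y, mb_y_len = train_data[0]
print(mb_x.shape, mb_y.shape)      # (batch_size, longest English sentence) and (batch_size, longest Chinese sentence) in this batch
print(mb_x_len[:5], mb_y_len[:5])  # per-sentence lengths, used later for packing and masking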

A version without attention

Below is a simpler encoder-decoder model without attention.

class PlainEncoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(PlainEncoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        # sort by length so that pack_padded_sequence can be used
        sorted_len, sorted_idx = lengths.sort(0, descending=True)
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))

        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        # restore the original batch order
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        return out, hid[[-1]]

class PlainDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(PlainDecoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, y, y_lengths, hid):
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]

        y_sorted = self.dropout(self.embed(y_sorted))  # batch_size, output_length, embed_size

        packed_seq = nn.utils.rnn.pack_padded_sequence(
            y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        output = F.log_softmax(self.out(output_seq), -1)

        return output, hid

class PlainSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(PlainSeq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid = self.decoder(y=y, y_lengths=y_lengths, hid=hid)
        return output, None

    def translate(self, x, x_lengths, y, max_length=10):
        encoder_out, hid = self.encoder(x, x_lengths)
        preds = []
        batch_size = x.shape[0]
        for i in range(max_length):
            output, hid = self.decoder(y=y,
                                       y_lengths=torch.ones(batch_size).long().to(y.device),
                                       hid=hid)
            # greedy decoding: feed the most likely word back in as the next input
            y = output.max(2)[1].view(batch_size, 1)
            preds.append(y)

        return torch.cat(preds, 1), None
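
A quick shape check of the plain model with toy tensors (illustrative only; the sizes here are assumptions, not the ones used for training below):

enc = PlainEncoder(vocab_size=100, hidden_size=64)
dec = PlainDecoder(vocab_size=120, hidden_size=64)
toy_model = PlainSeq2Seq(enc, dec)
x = torch.randint(1, 100, (4, 7))   # 4 source sentences, up to 7 tokens
x_len = torch.tensor([7, 6, 5, 3])
y = torch.randint(1, 120, (4, 6))   # 4 target sentences, up to 6 tokens
y_len = torch.tensor([6, 5, 4, 2])
out, _ = toy_model(x, x_len, y, y_len)
print(out.shape)                    # torch.Size([4, 6, 120]): per-step log-probabilities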
# masked cross entropy loss
class LanguageModelCriterion(nn.Module):
    def __init__(self):
        super(LanguageModelCriterion, self).__init__()

    def forward(self, input, target, mask):
        # input: batch_size, seq_len, vocab_size (log-probabilities)
        # target, mask: batch_size, seq_len
        input = input.contiguous().view(-1, input.size(2))
        target = target.contiguous().view(-1, 1)
        mask = mask.contiguous().view(-1, 1)
        # negative log-likelihood of the target words, averaged over unmasked positions
        output = -input.gather(1, target) * mask
        output = torch.sum(output) / torch.sum(mask)

        return output
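
A tiny sanity check of the masked loss (toy numbers, an assumption for illustration): the second sequence is padded after its first token, so only three positions contribute.

logp = torch.log_softmax(torch.randn(2, 2, 5), dim=-1)  # batch=2, seq_len=2, vocab=5
tgt = torch.tensor([[1, 3], [4, 0]])
mask = torch.tensor([[1., 1.], [1., 0.]])
print(LanguageModelCriterion()(logp, tgt, mask))        # mean NLL over the 3 unmasked tokens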
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dropout = 0.2
hidden_size = 100
encoder = PlainEncoder(vocab_size=en_total_words,
                       hidden_size=hidden_size,
                       dropout=dropout)
decoder = PlainDecoder(vocab_size=cn_total_words,
                       hidden_size=hidden_size,
                       dropout=dropout)
model = PlainSeq2Seq(encoder, decoder)
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())
def evaluate(model, data):
    model.eval()
    total_num_words = total_loss = 0.
    with torch.no_grad():
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            # decoder input is the target without EOS; decoder output is the target without BOS
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len - 1).to(device).long()
            mb_y_len[mb_y_len <= 0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words
    print("Evaluation loss", total_loss / total_num_words)
def train(model, data, num_epochs=20):
    for epoch in range(num_epochs):
        model.train()
        total_num_words = total_loss = 0.
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len - 1).to(device).long()
            mb_y_len[mb_y_len <= 0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words

            # update the model
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()

            if it % 100 == 0:
                print("Epoch", epoch, "iteration", it, "loss", loss.item())

        print("Epoch", epoch, "Training loss", total_loss / total_num_words)
        if epoch % 5 == 0:
            evaluate(model, dev_data)

train(model, train_data, num_epochs=20)
Epoch 0 iteration 0 loss 8.050323486328125
Epoch 0 iteration 100 loss 5.278979301452637
Epoch 0 iteration 200 loss 4.444733619689941
Epoch 0 Training loss 5.433318799975385
Evaluation loss 4.822829000278033
Epoch 1 iteration 0 loss 4.692166805267334
Epoch 1 iteration 100 loss 4.708909511566162
Epoch 1 iteration 200 loss 3.8643922805786133
Epoch 1 Training loss 4.5993410716009135
Epoch 2 iteration 0 loss 4.17959451675415
Epoch 2 iteration 100 loss 4.352121829986572
Epoch 2 iteration 200 loss 3.5356297492980957
Epoch 2 Training loss 4.198561833806036
Epoch 3 iteration 0 loss 3.8728413581848145
Epoch 3 iteration 100 loss 4.134408950805664
Epoch 3 iteration 200 loss 3.303772211074829
Epoch 3 Training loss 3.9386860033522813
Epoch 4 iteration 0 loss 3.64646053314209
Epoch 4 iteration 100 loss 3.947233200073242
Epoch 4 iteration 200 loss 3.1333234310150146
Epoch 4 Training loss 3.745685762442693
Epoch 5 iteration 0 loss 3.481276035308838
Epoch 5 iteration 100 loss 3.827454090118408
Epoch 5 iteration 200 loss 2.9994454383850098
Epoch 5 Training loss 3.5913285724858954
Evaluation loss 3.6815984345855037
Epoch 6 iteration 0 loss 3.3354697227478027
Epoch 6 iteration 100 loss 3.6918392181396484
Epoch 6 iteration 200 loss 2.8618223667144775
Epoch 6 Training loss 3.465248799091302
Epoch 7 iteration 0 loss 3.2224643230438232
Epoch 7 iteration 100 loss 3.5980327129364014
Epoch 7 iteration 200 loss 2.783277988433838
Epoch 7 Training loss 3.357013859409834
Epoch 8 iteration 0 loss 3.141510248184204
Epoch 8 iteration 100 loss 3.5131657123565674
Epoch 8 iteration 200 loss 2.715005397796631
Epoch 8 Training loss 3.2614931554428166
Epoch 9 iteration 0 loss 3.0618908405303955
Epoch 9 iteration 100 loss 3.4437694549560547
Epoch 9 iteration 200 loss 2.5995192527770996
Epoch 9 Training loss 3.1806184197973404
Epoch 10 iteration 0 loss 2.9288880825042725
Epoch 10 iteration 100 loss 3.350996971130371
Epoch 10 iteration 200 loss 2.5103659629821777
Epoch 10 Training loss 3.101915731518774
Evaluation loss 3.393061912401112
Epoch 11 iteration 0 loss 2.874830722808838
Epoch 11 iteration 100 loss 3.3034920692443848
Epoch 11 iteration 200 loss 2.4885127544403076
Epoch 11 Training loss 3.0369929761565384
Epoch 12 iteration 0 loss 2.8056483268737793
Epoch 12 iteration 100 loss 3.2505335807800293
Epoch 12 iteration 200 loss 2.4071717262268066
Epoch 12 Training loss 2.973809002606383
Epoch 13 iteration 0 loss 2.7353591918945312
Epoch 13 iteration 100 loss 3.178480863571167
Epoch 13 iteration 200 loss 2.3422422409057617
Epoch 13 Training loss 2.9169208222083847
Epoch 14 iteration 0 loss 2.6794426441192627
Epoch 14 iteration 100 loss 3.129685401916504
Epoch 14 iteration 200 loss 2.3255887031555176
Epoch 14 Training loss 2.86419656519231
Epoch 15 iteration 0 loss 2.6482393741607666
Epoch 15 iteration 100 loss 3.0710315704345703
Epoch 15 iteration 200 loss 2.2372782230377197
Epoch 15 Training loss 2.8170104509222287
Evaluation loss 3.2708830728055336
Epoch 16 iteration 0 loss 2.567857503890991
Epoch 16 iteration 100 loss 3.0710268020629883
Epoch 16 iteration 200 loss 2.238800525665283
Epoch 16 Training loss 2.771683479683666
Epoch 17 iteration 0 loss 2.5122745037078857
Epoch 17 iteration 100 loss 3.002455472946167
Epoch 17 iteration 200 loss 2.1964993476867676
Epoch 17 Training loss 2.733348611161267
Epoch 18 iteration 0 loss 2.49585223197937
Epoch 18 iteration 100 loss 2.971094846725464
Epoch 18 iteration 200 loss 2.1383423805236816
Epoch 18 Training loss 2.6926882812821322
Epoch 19 iteration 0 loss 2.436241388320923
Epoch 19 iteration 100 loss 2.942230224609375
Epoch 19 iteration 200 loss 2.0685524940490723
Epoch 19 Training loss 2.6545419067862515
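
Because the reported loss is the average per-word negative log-likelihood, its exponential can be read as a perplexity (a quick back-of-the-envelope check, not part of the original notebook):

import math
print(math.exp(2.65))  # ≈ 14, roughly the perplexity corresponding to the final training loss above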
def translate_dev(i):
    en_sent = " ".join([inv_en_dict[w] for w in dev_en[i]])
    print(en_sent)
    cn_sent = " ".join([inv_cn_dict[w] for w in dev_cn[i]])
    print("".join(cn_sent))

    mb_x = torch.from_numpy(np.array(dev_en[i]).reshape(1, -1)).long().to(device)
    mb_x_len = torch.from_numpy(np.array([len(dev_en[i])])).long().to(device)
    bos = torch.Tensor([[cn_dict["BOS"]]]).long().to(device)

    translation, attn = model.translate(mb_x, mb_x_len, bos)
    translation = [inv_cn_dict[i] for i in translation.data.cpu().numpy().reshape(-1)]
    # cut the translation at the first EOS
    trans = []
    for word in translation:
        if word != "EOS":
            trans.append(word)
        else:
            break
    print("".join(trans))

for i in range(100, 120):
    translate_dev(i)
    print()
BOS you have nice skin . EOS
BOS 你 的 皮 膚 真 好 。 EOS
你必須吃。

BOS you 're UNK correct . EOS
BOS 你 部 分 正 确 。 EOS
你是一个好的。

BOS everyone admired his courage . EOS
BOS 每 個 人 都 佩 服 他 的 勇 氣 。 EOS
每个人都在学习。

BOS what time is it ? EOS
BOS 几 点 了 ? EOS
它什么是谁?

BOS i 'm free tonight . EOS
BOS 我 今 晚 有 空 。 EOS
我很快就會。

BOS here is your book . EOS
BOS 這 是 你 的 書 。 EOS
這是你的。

BOS they are at lunch . EOS
BOS 他 们 在 吃 午 饭 。 EOS
他们有个大学。

BOS this chair is UNK . EOS
BOS 這 把 椅 子 很 UNK 。 EOS
這個房間是一個人的。

BOS it 's pretty heavy . EOS
BOS 它 真 重 。 EOS
它是一個好的。

BOS many attended his funeral . EOS
BOS 很 多 人 都 参 加 了 他 的 葬 礼 。 EOS
許多的人都喜歡茶。

BOS training will be provided . EOS
BOS 会 有 训 练 。 EOS
要下雨。

BOS someone is watching you . EOS
BOS 有 人 在 看 著 你 。 EOS
有人是你的。

BOS i slapped his face . EOS
BOS 我 摑 了 他 的 臉 。 EOS
我認為他的手臂。

BOS i like UNK music . EOS
BOS 我 喜 歡 流 行 音 樂 。 EOS
我喜歡打棒球。

BOS tom had no children . EOS
BOS T o m 沒 有 孩 子 。 EOS
汤姆没有人。

BOS please lock the door . EOS
BOS 請 把 門 鎖 上 。 EOS
請把你的車。

BOS tom has calmed down . EOS
BOS 汤 姆 冷 静 下 来 了 。 EOS
汤姆在花園裡。

BOS please speak more loudly . EOS
BOS 請 說 大 聲 一 點 兒 。 EOS
請稍好喝咖啡。

BOS keep next sunday free . EOS
BOS 把 下 周 日 空 出 来 。 EOS
繼續工作很多。

BOS i made a mistake . EOS
BOS 我 犯 了 一 個 錯 。 EOS
我是一個小孩。

With the data fully processed, we can now build the attention-based seq2seq model.

Encoder

  • The encoder's job is to pass the input words through an embedding layer and a GRU, turning them into hidden states that later serve as the context vectors
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, enc_hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(enc_hidden_size * 2, dec_hidden_size)

    def forward(self, x, lengths):
        sorted_len, sorted_idx = lengths.sort(0, descending=True)
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))

        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        # concatenate the last forward and backward hidden states and project
        # them to the decoder's hidden size to initialize the decoder
        hid = torch.cat([hid[-2], hid[-1]], dim=1)
        hid = torch.tanh(self.fc(hid)).unsqueeze(0)

        return out, hid
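
A quick shape check with toy sizes (assumed numbers, for illustration only):

enc = Encoder(vocab_size=100, embed_size=8, enc_hidden_size=16, dec_hidden_size=24)
out, hid = enc(torch.randint(1, 100, (4, 7)), torch.tensor([7, 6, 5, 3]))
print(out.shape)  # torch.Size([4, 7, 32]): 2 * enc_hidden_size, because the GRU is bidirectional
print(hid.shape)  # torch.Size([1, 4, 24]): projected to dec_hidden_size to initialize the decoder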

Luong Attention

  • Compute the attention output from the context vectors (the encoder states) and the current decoder hidden states; see the formulas sketched below
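
In Luong's "general" formulation, which linear_in and linear_out implement below, with encoder states $\bar h_s$ and decoder state $h_t$:

$$\mathrm{score}(h_t,\bar h_s)=h_t^\top W_a \bar h_s,\qquad a_t(s)=\frac{\exp\big(\mathrm{score}(h_t,\bar h_s)\big)}{\sum_{s'}\exp\big(\mathrm{score}(h_t,\bar h_{s'})\big)}$$

$$c_t=\sum_s a_t(s)\,\bar h_s,\qquad \tilde h_t=\tanh\big(W_c[c_t;h_t]\big)$$

Here $W_a$ corresponds to linear_in and $W_c$ to linear_out.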
class Attention(nn.Module):
    def __init__(self, enc_hidden_size, dec_hidden_size):
        super(Attention, self).__init__()

        self.enc_hidden_size = enc_hidden_size
        self.dec_hidden_size = dec_hidden_size

        self.linear_in = nn.Linear(enc_hidden_size * 2, dec_hidden_size, bias=False)
        self.linear_out = nn.Linear(enc_hidden_size * 2 + dec_hidden_size, dec_hidden_size)

    def forward(self, output, context, mask):
        # output: batch_size, output_len, dec_hidden_size
        # context: batch_size, context_len, 2*enc_hidden_size

        batch_size = output.size(0)
        output_len = output.size(1)
        input_len = context.size(1)

        context_in = self.linear_in(context.view(batch_size * input_len, -1)).view(
            batch_size, input_len, -1)  # batch_size, context_len, dec_hidden_size

        # context_in.transpose(1, 2): batch_size, dec_hidden_size, context_len
        # output: batch_size, output_len, dec_hidden_size
        attn = torch.bmm(output, context_in.transpose(1, 2))
        # batch_size, output_len, context_len

        # mask out padding positions before the softmax (note the in-place masked_fill_)
        attn.data.masked_fill_(mask, -1e6)

        attn = F.softmax(attn, dim=2)
        # batch_size, output_len, context_len

        context = torch.bmm(attn, context)
        # batch_size, output_len, 2*enc_hidden_size

        output = torch.cat((context, output), dim=2)  # batch_size, output_len, 2*enc_hidden_size + dec_hidden_size

        output = output.view(batch_size * output_len, -1)
        output = torch.tanh(self.linear_out(output))
        output = output.view(batch_size, output_len, -1)
        return output, attn

Decoder

  • The decoder decides the next output word based on the words translated so far and the context vectors
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(enc_hidden_size, dec_hidden_size)
        self.rnn = nn.GRU(embed_size, dec_hidden_size, batch_first=True)
        self.out = nn.Linear(dec_hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_mask(self, x_len, y_len):
        # returns a mask of shape batch_size, max_x_len, max_y_len that is True
        # at the positions the attention should ignore (padding)
        device = x_len.device
        max_x_len = x_len.max()
        max_y_len = y_len.max()
        x_mask = torch.arange(max_x_len, device=device)[None, :] < x_len[:, None]
        y_mask = torch.arange(max_y_len, device=device)[None, :] < y_len[:, None]
        mask = ~(x_mask[:, :, None] & y_mask[:, None, :])
        return mask

    def forward(self, ctx, ctx_lengths, y, y_lengths, hid):
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]

        y_sorted = self.dropout(self.embed(y_sorted))  # batch_size, output_length, embed_size

        packed_seq = nn.utils.rnn.pack_padded_sequence(
            y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        mask = self.create_mask(y_lengths, ctx_lengths)

        output, attn = self.attention(output_seq, ctx, mask)
        output = F.log_softmax(self.out(output), -1)

        return output, hid, attn

Seq2Seq

  • Finally, the Seq2Seq model ties the encoder, attention, and decoder together
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid, attn = self.decoder(ctx=encoder_out,
                                         ctx_lengths=x_lengths,
                                         y=y,
                                         y_lengths=y_lengths,
                                         hid=hid)
        return output, attn

    def translate(self, x, x_lengths, y, max_length=100):
        encoder_out, hid = self.encoder(x, x_lengths)
        preds = []
        batch_size = x.shape[0]
        attns = []
        for i in range(max_length):
            output, hid, attn = self.decoder(ctx=encoder_out,
                                             ctx_lengths=x_lengths,
                                             y=y,
                                             y_lengths=torch.ones(batch_size).long().to(y.device),
                                             hid=hid)
            # greedy decoding, keeping the attention weights for inspection
            y = output.max(2)[1].view(batch_size, 1)
            preds.append(y)
            attns.append(attn)
        return torch.cat(preds, 1), torch.cat(attns, 1)
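
The attention weights returned by translate can be visualized as a heat map. Below is a hedged sketch using matplotlib (assuming it is installed; this helper is not part of the original notebook):

import matplotlib.pyplot as plt

def plot_attention(attn, src_tokens, tgt_tokens):
    # attn: numpy array of shape (len(tgt_tokens), len(src_tokens)) for one sentence
    fig, ax = plt.subplots()
    ax.imshow(attn, cmap="viridis")
    ax.set_xticks(range(len(src_tokens)))
    ax.set_xticklabels(src_tokens, rotation=90)
    ax.set_yticks(range(len(tgt_tokens)))
    ax.set_yticklabels(tgt_tokens)
    ax.set_xlabel("source (English)")
    ax.set_ylabel("translation (Chinese)")
    plt.show()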

Training

dropout = 0.2
embed_size = hidden_size = 100
encoder = Encoder(vocab_size=en_total_words,
                  embed_size=embed_size,
                  enc_hidden_size=hidden_size,
                  dec_hidden_size=hidden_size,
                  dropout=dropout)
decoder = Decoder(vocab_size=cn_total_words,
                  embed_size=embed_size,
                  enc_hidden_size=hidden_size,
                  dec_hidden_size=hidden_size,
                  dropout=dropout)
model = Seq2Seq(encoder, decoder)
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())
train(model, train_data, num_epochs=30)
Epoch 0 iteration 0 loss 8.078022003173828
Epoch 0 iteration 100 loss 5.414377689361572
Epoch 0 iteration 200 loss 4.643333435058594
Epoch 0 Training loss 5.485134587536152
Evaluation loss 5.067514630874862
Epoch 1 iteration 0 loss 4.940210342407227
Epoch 1 iteration 100 loss 4.9903435707092285
Epoch 1 iteration 200 loss 4.186498641967773
Epoch 1 Training loss 4.877356682952294
Epoch 2 iteration 0 loss 4.509239196777344
Epoch 2 iteration 100 loss 4.570853233337402
Epoch 2 iteration 200 loss 3.7934508323669434
Epoch 2 Training loss 4.453642889638262
Epoch 3 iteration 0 loss 4.11014986038208
Epoch 3 iteration 100 loss 4.230580806732178
Epoch 3 iteration 200 loss 3.4451844692230225
Epoch 3 Training loss 4.105205834096106
Epoch 4 iteration 0 loss 3.788179397583008
Epoch 4 iteration 100 loss 3.984476089477539
Epoch 4 iteration 200 loss 3.205059289932251
Epoch 4 Training loss 3.8313639103406314
Epoch 5 iteration 0 loss 3.572876214981079
Epoch 5 iteration 100 loss 3.7907521724700928
Epoch 5 iteration 200 loss 3.0604655742645264
Epoch 5 Training loss 3.61275750220716
Evaluation loss 3.6225900108158475
Epoch 6 iteration 0 loss 3.331376552581787
Epoch 6 iteration 100 loss 3.607234239578247
Epoch 6 iteration 200 loss 2.8438034057617188
Epoch 6 Training loss 3.4240881394610914
Epoch 7 iteration 0 loss 3.1553823947906494
Epoch 7 iteration 100 loss 3.4283368587493896
Epoch 7 iteration 200 loss 2.679870367050171
Epoch 7 Training loss 3.2619650765874195
Epoch 8 iteration 0 loss 3.0175576210021973
Epoch 8 iteration 100 loss 3.313087224960327
Epoch 8 iteration 200 loss 2.573970079421997
Epoch 8 Training loss 3.119750910546451
Epoch 9 iteration 0 loss 2.8687644004821777
Epoch 9 iteration 100 loss 3.2016961574554443
Epoch 9 iteration 200 loss 2.4501001834869385
Epoch 9 Training loss 2.9937007481445184
Epoch 10 iteration 0 loss 2.7964212894439697
Epoch 10 iteration 100 loss 3.094231128692627
Epoch 10 iteration 200 loss 2.2865397930145264
Epoch 10 Training loss 2.879919764606877
Evaluation loss 3.164760209368642
Epoch 11 iteration 0 loss 2.6683473587036133
Epoch 11 iteration 100 loss 3.008727788925171
Epoch 11 iteration 200 loss 2.1880834102630615
Epoch 11 Training loss 2.7794466071573467
Epoch 12 iteration 0 loss 2.5640454292297363
Epoch 12 iteration 100 loss 2.896376132965088
Epoch 12 iteration 200 loss 2.1036128997802734
Epoch 12 Training loss 2.684113484535982
Epoch 13 iteration 0 loss 2.520007371902466
Epoch 13 iteration 100 loss 2.8189423084259033
Epoch 13 iteration 200 loss 2.0698890686035156
Epoch 13 Training loss 2.5990255668547055
Epoch 14 iteration 0 loss 2.42832612991333
Epoch 14 iteration 100 loss 2.7819204330444336
Epoch 14 iteration 200 loss 1.923954725265503
Epoch 14 Training loss 2.5176252404633574
Epoch 15 iteration 0 loss 2.360988140106201
Epoch 15 iteration 100 loss 2.6843974590301514
Epoch 15 iteration 200 loss 1.912152886390686
Epoch 15 Training loss 2.4463321701504275
Evaluation loss 2.9698491313827047
Epoch 16 iteration 0 loss 2.2877912521362305
Epoch 16 iteration 100 loss 2.6055469512939453
Epoch 16 iteration 200 loss 1.8231658935546875
Epoch 16 Training loss 2.3756549535366713
Epoch 17 iteration 0 loss 2.191697597503662
Epoch 17 iteration 100 loss 2.5865063667297363
Epoch 17 iteration 200 loss 1.7817124128341675
Epoch 17 Training loss 2.313343924902058
Epoch 18 iteration 0 loss 2.1245803833007812
Epoch 18 iteration 100 loss 2.525496482849121
Epoch 18 iteration 200 loss 1.672200322151184
Epoch 18 Training loss 2.2498218108556114
Epoch 19 iteration 0 loss 2.06477427482605
Epoch 19 iteration 100 loss 2.443316698074341
Epoch 19 iteration 200 loss 1.6326298713684082
Epoch 19 Training loss 2.19988960411091
Epoch 20 iteration 0 loss 2.0234487056732178
Epoch 20 iteration 100 loss 2.416968822479248
Epoch 20 iteration 200 loss 1.583616852760315
Epoch 20 Training loss 2.1513965044733827
Evaluation loss 2.8699020465835643
Epoch 21 iteration 0 loss 2.008730411529541
Epoch 21 iteration 100 loss 2.3642444610595703
Epoch 21 iteration 200 loss 1.5385680198669434
Epoch 21 Training loss 2.098746986360735
Epoch 22 iteration 0 loss 1.910429835319519
Epoch 22 iteration 100 loss 2.339489459991455
Epoch 22 iteration 200 loss 1.4784246683120728
Epoch 22 Training loss 2.051404798098097
Epoch 23 iteration 0 loss 1.8959044218063354
Epoch 23 iteration 100 loss 2.2653536796569824
Epoch 23 iteration 200 loss 1.4792706966400146
Epoch 23 Training loss 2.00636701965731
Epoch 24 iteration 0 loss 1.8477107286453247
Epoch 24 iteration 100 loss 2.1904118061065674
Epoch 24 iteration 200 loss 1.3925689458847046
Epoch 24 Training loss 1.965628425139225
Epoch 25 iteration 0 loss 1.7790645360946655
Epoch 25 iteration 100 loss 2.182420492172241
Epoch 25 iteration 200 loss 1.3576843738555908
Epoch 25 Training loss 1.9238889035465652
Evaluation loss 2.826008448512912
Epoch 26 iteration 0 loss 1.73543381690979
Epoch 26 iteration 100 loss 2.1740329265594482
Epoch 26 iteration 200 loss 1.328704595565796
Epoch 26 Training loss 1.889945533318946
Epoch 27 iteration 0 loss 1.7498269081115723
Epoch 27 iteration 100 loss 2.1384894847869873
Epoch 27 iteration 200 loss 1.277467966079712
Epoch 27 Training loss 1.852515173441663
Epoch 28 iteration 0 loss 1.6980342864990234
Epoch 28 iteration 100 loss 2.1195883750915527
Epoch 28 iteration 200 loss 1.2595137357711792
Epoch 28 Training loss 1.8210893462516964
Epoch 29 iteration 0 loss 1.6773594617843628
Epoch 29 iteration 100 loss 2.0760860443115234
Epoch 29 iteration 200 loss 1.2345834970474243
Epoch 29 Training loss 1.7873437400435428
for i in range(100, 120):
    translate_dev(i)
    print()
BOS you have nice skin . EOS
BOS 你 的 皮 膚 真 好 。 EOS
你好害怕。

BOS you 're UNK correct . EOS
BOS 你 部 分 正 确 。 EOS
你是全子的声音。

BOS everyone admired his courage . EOS
BOS 每 個 人 都 佩 服 他 的 勇 氣 。 EOS
他的袋子是他的勇氣。

BOS what time is it ? EOS
BOS 几 点 了 ? EOS
多少时间是什么?

BOS i 'm free tonight . EOS
BOS 我 今 晚 有 空 。 EOS
我今晚有空。

BOS here is your book . EOS
BOS 這 是 你 的 書 。 EOS
这儿是你的书。

BOS they are at lunch . EOS
BOS 他 们 在 吃 午 饭 。 EOS
他们在午餐。

BOS this chair is UNK . EOS
BOS 這 把 椅 子 很 UNK 。 EOS
這些花一下是正在的。

BOS it 's pretty heavy . EOS
BOS 它 真 重 。 EOS
它很美的脚。

BOS many attended his funeral . EOS
BOS 很 多 人 都 参 加 了 他 的 葬 礼 。 EOS
多多衛年轻地了他。

BOS training will be provided . EOS
BOS 会 有 训 练 。 EOS
别将被付錢。

BOS someone is watching you . EOS
BOS 有 人 在 看 著 你 。 EOS
有人看你。

BOS i slapped his face . EOS
BOS 我 摑 了 他 的 臉 。 EOS
我把他的臉抱歉。

BOS i like UNK music . EOS
BOS 我 喜 歡 流 行 音 樂 。 EOS
我喜歡音樂。

BOS tom had no children . EOS
BOS T o m 沒 有 孩 子 。 EOS
汤姆没有照顧孩子。

BOS please lock the door . EOS
BOS 請 把 門 鎖 上 。 EOS
请把門開門。

BOS tom has calmed down . EOS
BOS 汤 姆 冷 静 下 来 了 。 EOS
汤姆在做了。

BOS please speak more loudly . EOS
BOS 請 說 大 聲 一 點 兒 。 EOS
請說更多。

BOS keep next sunday free . EOS
BOS 把 下 周 日 空 出 来 。 EOS
繼續下週一下一步。

BOS i made a mistake . EOS
BOS 我 犯 了 一 個 錯 。 EOS
我做了一件事。
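
Beyond eyeballing individual translations, the dev set can be scored with corpus-level BLEU from nltk. The sketch below (not part of the original notebook) reuses the greedy decoding loop from translate_dev; treating each Chinese character as a token is an assumption that matches how the Chinese side was tokenized here.

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

def bleu_on_dev(model, n=500):
    refs, hyps = [], []
    for i in range(min(n, len(dev_en))):
        mb_x = torch.from_numpy(np.array(dev_en[i]).reshape(1, -1)).long().to(device)
        mb_x_len = torch.from_numpy(np.array([len(dev_en[i])])).long().to(device)
        bos = torch.Tensor([[cn_dict["BOS"]]]).long().to(device)
        translation, _ = model.translate(mb_x, mb_x_len, bos)
        hyp = []
        for w in translation.data.cpu().numpy().reshape(-1):
            word = inv_cn_dict[w]
            if word == "EOS":
                break
            hyp.append(word)
        refs.append([[inv_cn_dict[w] for w in dev_cn[i][1:-1]]])  # strip BOS/EOS from the reference
        hyps.append(hyp)
    # smoothing avoids zero scores on short character sequences
    return corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1)

print(bleu_on_dev(model))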